其他
kmeans算法python代码——可直接运行
在安装了相应依赖包情况下,以下代码可直接运行。
1# -*- coding:utf-8 -*-
2
3import numpy as np
4import random as rd
5import matplotlib.pyplot as plt
6import math
7
8def printLine():
9 print '----------------------------------------------------------------------------'
10
11#计算聚类中心
12def cent(x):
13 return(sum(x)/len(x))
14
15#距离, 返回s,C,分别是距离平方和与聚类方案
16def f(center):
17 # c0 = []
18 # c1 = []
19 # c2 = []
20 c = [[] for i in range(k)]
21 D = np.arange(k*n).reshape(k,n)
22 d = np.array([center[i]-dat.T for i in range(k)])
23 for i in range(k):
24 D[i] = sum((d[i]**2).T)
25 for i in range(n):
26 ind = D.T[i].argmin()
27 c[ind].append(i)
28 C = [np.array([dat.T[i] for i in j]) for j in c]
29 print(c)
30 s = 0
31 for i in C:
32 s+=dist(i)
33 return(s,C)
34
35#计算各点到聚类中心的距离之和
36def dist(x):
37 #聚类中心
38 m0 = cent(x)
39 dis = sum(sum((x-m0)**2))
40 return dis
41
42def run():
43 # 存储距离和
44 s_sum = []
45 #---随机产生聚类中心----#
46 center = rd.sample(range(n),k)
47 center = np.array([dat.T[i] for i in center])
48 print '初始化聚类中心为:'.decode('utf-8')
49 print(center)
50 printLine()
51 #初始距离和
52 print '第1次计算!'.decode('utf-8')
53 dd,C = f(center)
54 s_sum.append(dd)
55 print ('距离和为'+str(dd)).decode('utf-8')
56 printLine()
57 print('第2次计算!'.decode('utf-8'))
58 center = [cent(i) for i in C]
59 Dd,C = f(center)
60 s_sum.append(Dd)
61 print ('距离和为'+str(Dd)).decode('utf-8')
62 # 前面已经计算2次了,所以这里从第三次开始计算
63 K = 3
64 while(K<n_max):
65 printLine()
66 #两次差值很小并且计算了一定次数
67 if(math.sqrt(abs(dd-Dd)) < 0 and K>20):
68 break;
69 print ('第'+str(K)+'次计算!').decode('utf-8')
70 dd = Dd
71 print ('距离和为'+str(dd)).decode('utf-8')
72 #当前聚类中心
73 center = [cent(i) for i in C]
74 Dd,C = f(center)
75 s_sum.append(Dd)
76 K+=1
77
78 #-----------------聚类结果可视化部分--------------------#
79 j = 0
80 for i in C:
81 if(j == 0):
82 plt.plot(i.T[0],i.T[1],'ro')
83 if(j == 1):
84 plt.plot(i.T[0],i.T[1],'b+')
85 if(j == 2):
86 plt.plot(i.T[0],i.T[1],'g*')
87 if(j == 3):
88 plt.plot(i.T[0],i.T[1], 'c<')
89 j+=1
90 plt.show()
91 x = range(len(s_sum))
92 plt.plot(x, s_sum)
93 plt.plot(x, s_sum, 'ro')
94 plt.show()
95
96
97print '==============================================================================='
98#数据
99dat = np.array([[14,22,15,20,30,20,32,13,23,20,21,22,23,24,35,18,20,31,14]
100 ,[15,28,18,30,35,15,30,15,25,23,24,25,26,27,30,15,24,33,12]])
101dat = np.random.randint(0, 30, (2, 40))
102print(dat)
103#=========================聚类中心======================#
104n = len(dat[0])
105N = len(dat)*n
106k = 4
107n_max = 50
108
109# 程序入口
110if __name__ == '__main__':
111 print '==============================================================================='
112 run()
可以通过修改k和n_max的值,改变聚类数量和测试样本数量。
点击【阅读原文】购买python spark大数据课程,限时优惠。